/****************************************************************************
** This code creates a correlation matrix with the ORIV procedure			*
**																			*
** A variant of the ORIV procedure is created to deal with the 				*
** case with non-orthogonal measurement error. This is applied to produce	*
** a correlation matrix that is stored in STATA memory, and then can be		*
** output or analyzed														*
**																			*
** Produced by Erik Snowberg, October 13, 2022								*
** 		Based on Code by Jonathan Chapman and Erik Snowberg					*
****************************************************************************/

// Start from a clean slate: remove every program and matrix currently defined
program drop _all
matrix drop _all

/******************************************************************************
** 1) Inputs from file calling the code
******************************************************************************/
// Copy each variable list from the global set by the calling file into a
// local of the same name. This must happen before the globals are dropped below.
foreach listName in twoByTwoList oneByTwoListEV oneByTwoListNotEV twoByTwoListNonOrthog {
	local `listName' = "${`listName'}"
}

// Snapshot the dataset currently in memory to a temporary file
tempfile correlationsTemp
save `correlationsTemp'

// The inputs now live in locals (or are not needed here), so drop the globals
macro drop idName correlationFile correlationDirectory oneByTwoListNotEV oneByTwoListEV twoByTwoList twoByTwoListNonOrthog weightName

// Single list of all one-measure variables (equal-variance ones first)
local oneByTwoListComb `oneByTwoListEV' `oneByTwoListNotEV'

/*******************************************************************************************************
** 2) Define ORIV program to remove non-orthogonal elements
*******************************************************************************************************/
program ORIVCorrelationNonOrthog, rclass
	/*
	ORIVCorrelationNonOrthog

	Estimates a measurement-error-corrected correlation between two constructs
	that are each measured twice, and returns it with a standard error.

	Options
	  group1(varlist)  the two measures of the first construct (used as outcome)
	  group2(varlist)  the two measures of the second construct (used as
	                   regressor and instrument)
	  [if]             restricts the estimation sample
	  [pweight=exp]    sampling weights; used as pweights in ivregress and as
	                   aweights in summarize / correlate
	  equalvar         accepted for call compatibility; not used in this program
	  bsreps(#)        number of bootstrap replications (default 10000);
	                   0 or less skips the bootstrap and reports the asymptotic SE

	Method
	  Every input is standardized. The data are stacked in two copies:
	    copy 0: outcome = group1 measure 1, regressor = group2 measure 1,
	            instrument = group2 measure 2
	    copy 1: outcome = group1 measure 2, regressor = group2 measure 2,
	            instrument = group2 measure 1
	  A 2SLS regression with copy-specific constants gives the coefficient,
	  which is rescaled by sqrt( cov(group2 measures) / cov(group1 measures) ).

	Returns
	  r(corr)  the corrected correlation
	  r(se)    bootstrap SE if bsreps > 0, otherwise the rescaled asymptotic SE

	The caller's dataset is preserved and restored.
	*/
	preserve // So original dataset is restored

	*** (0) Syntax handling, misc
	syntax [if] [pweight /], GROUP1(varlist numeric) GROUP2(varlist numeric) [equalvar  BSREPS(integer 10000)]

	** The stacking below needs measures 1 and 2 of each group; fail early with
	** a clear message instead of a later "variable not found" error
	local nGroup1 : word count `group1'
	local nGroup2 : word count `group2'
	if `nGroup1' < 2 | `nGroup2' < 2 {
		display as error "group1() and group2() must each contain two measures"
		exit 102
	}

	** Keep only the data that satisfies the [if] option
	if "`if'" != "" {
		quietly keep `if'
	}

	** Generate weight macros
	if "`weight'" != "" {
		quietly gen __weightvar = `exp'
		local pw_exp =  "[pweight=__weightvar]" // used for ivreg
		local aw_exp =  "[aweight=__weightvar]" // used for summarize,corr
	}
	else {
		local pw_exp = ""
		local aw_exp = ""
		* Placeholder so the keep statement below works with or without weights
		quietly gen __weightvar = .
	}

	** Count, rename, and standardize the input variables (group 1)
	local input_n1 = 0
	foreach var in `group1' {
		local ++input_n1
		quietly keep if ~missing(`var')
		quietly rename `var' __input1_`input_n1'
		quietly summarize __input1_`input_n1' `aw_exp'
		quietly replace __input1_`input_n1' = (__input1_`input_n1'-r(mean))/r(sd)
	}

	** Count, rename, and standardize the input variables (group 2)
	local input_n2 = 0
	foreach var in `group2' {
		local ++input_n2
		quietly keep if ~missing(`var')
		quietly rename `var' __input2_`input_n2'
		quietly summarize __input2_`input_n2' `aw_exp'
		quietly replace __input2_`input_n2' = (__input2_`input_n2'-r(mean))/r(sd)
	}

	*** (1) Construct the long dataset with two copies of each observation
	quietly keep __input1_* __input2_*  __weightvar
	quietly generate id = _n
	* Number of clusters, recorded BEFORE the data are doubled. The bootstrap
	* must draw ids from 1..nClusters; counting after expand would give twice
	* that, and half of the draws would then match no observation.
	local nClusters = _N

	quietly expand 2, gen(replicant)

	***  Define the LHS,RHS variables
	quietly {
		gen LHS = __input1_1 if replicant == 0
		replace LHS = __input1_2 if replicant == 1
		gen mainVar = __input2_1 if replicant == 0
		replace mainVar = __input2_2 if replicant == 1
		gen instrument = __input2_2 if replicant == 0
		replace instrument = __input2_1 if replicant == 1
	}
	* One constant per copy (the regression itself is run with nocons)
	forvalues x=0/1 {
		quietly gen constant`x' = (replicant == `x')
	}

	*** (2) Point estimate
	quietly ivregress 2sls LHS (mainVar = instrument) constant* `pw_exp', cluster(id) nocons
	local correctedCoefficient = _b[mainVar]
	local asymp_se = _se[mainVar]
	* Covariance between the two measures of each construct
	qui corr __input1_* `aw_exp' if replicant == 0 , cov
	local correctedYVar = r(cov_12)
	qui corr __input2_* `aw_exp'  if replicant == 0, cov
	local correctedXVar = r(cov_12)
	local correctedCorrelation = `correctedCoefficient'*sqrt(`correctedXVar'/`correctedYVar')
	local correctedAsympSE = `asymp_se'*sqrt(`correctedXVar'/`correctedYVar')

	*** (3) Bootstrap
	quietly compress
	sort id replicant
	tempfile dataFile
	quietly save `dataFile'

	tempname memHold
	tempfile rhoFile
	postfile `memHold' correctedRho using `rhoFile', replace

	forvalues i = 1/`bsreps' {
		* Resample clusters with replacement: draw ids, create both copies of
		* each draw, then pull the data in by merging on (id, replicant)
		clear
		qui set obs `nClusters'
		qui gen id = ceil(runiform()*`nClusters')
		//this is faster, if not a bit of a hack
		qui expand 2, generate(replicant)
		sort id replicant

		qui merge n:1 id replicant using `dataFile'
		qui drop if _merge == 2
		qui drop  _merge

		// Do not cluster in bootstrap, since only affects SEs, which we do not care about
		quietly ivregress 2sls LHS (mainVar = instrument) constant* `pw_exp' , nocons
		local bsCoefficient = _b[mainVar]

		* Rescale exactly as for the point estimate: Y covariance from the
		* group 1 measures, X covariance from the group 2 measures
		qui drop if replicant ~= 0
		qui corr __input1_* `aw_exp' , cov
		local bsYVar = r(cov_12)
		qui corr __input2_* `aw_exp' , cov
		local bsXVar = r(cov_12)

		post `memHold' (`bsCoefficient'*sqrt(`bsXVar'/`bsYVar'))
	}
	postclose `memHold'

	**********************
	*** (4) CALCULATE SE
	**********************
	if `bsreps'>0 {
		* Bootstrap SE = standard deviation of the bootstrapped correlations
		use `rhoFile', clear
		qui sum correctedRho
		local correctedBSSE = r(sd)
		local se = `correctedBSSE'
	}
	else {
		local se = `correctedAsympSE'
	}

	*********************
	*** (5) REPORT & RETURN
	*********************
	* Return the estimated correlation
	return scalar corr = `correctedCorrelation'

	* Return se
	return scalar se = `se'

	display "The ORIV Correlation is: " string(`correctedCorrelation' ,"%9.8f")
	if `bsreps'>0 {
		display "The Bootstrapped Standard Error of the ORIV Correlation is : " string(`se',"%9.8f")
	}
	else {
		display "   The asymptotic Standard Error of the ORIV Correlation is : " string(`se',"%9.8f")
		display "   (Note: not valid for testing hypotheses other than rho = 0)"
	}

	restore
end 

/*******************************************************************************************************
** 3) Correlation matrix: Weighted ORIV
*******************************************************************************************************/
** Create the two result matrices: one for correlations (corr) and one for
** standard errors (se). Rows and columns are ordered: two-measure variables,
** then one-measure variables with equal variance, then those without.
local numTwoByTwo : word count `twoByTwoList'
local numOneByTwoListEV : word count `oneByTwoListEV'
local numOneByTwoListNotEV  : word count `oneByTwoListNotEV'
** Show the size of each list as a quick check on the inputs
di `numTwoByTwo'
di `numOneByTwoListEV'
di `numOneByTwoListNotEV'
local matDim = `numTwoByTwo'+`numOneByTwoListEV'+`numOneByTwoListNotEV'
foreach i in corr se { 
		** Start with all cells missing; the sections below fill them in
		matrix define `i'=J(`matDim',`matDim',.)
		matrix rownames `i'=`twoByTwoList' `oneByTwoListEV' `oneByTwoListNotEV'
		matrix colnames `i'=`twoByTwoList' `oneByTwoListEV' `oneByTwoListNotEV'
}

** i) Both variables have two measures: one ORIV call for every ordered pair
**    of distinct variables, so cells (a,b) and (b,a) are each estimated
foreach rowVar of local twoByTwoList {
	** Every two-measure variable except the current row variable
	local partners : list twoByTwoList - rowVar
	foreach colVar of local partners {
		display "`rowVar'"
		display "`colVar'"
		ORIVCorrelation [pw=weight], group1(`rowVar'1 `rowVar'2) group2(`colVar'1 `colVar'2) equalvar bsreps($bootstrapReps)
		** Store the estimate and its SE, rounded to 8 decimals
		foreach stat in corr se {
			matrix `stat'[rownumb(`stat',"`rowVar'"),colnumb(`stat',"`colVar'")] = round(`r(`stat')',0.00000001)
		}
	}
}
** ii) One-measure variable vs two-measure variable, with the equalvar option
foreach singleVar of local oneByTwoListEV {
	foreach doubleVar of local twoByTwoList {
		display "`singleVar'"
		display "`doubleVar'"
		ORIVCorrelation [pw=weight], group1(`singleVar') group2(`doubleVar'1 `doubleVar'2) equalvar bsreps($bootstrapReps)
		** One estimate fills both the (single,double) and (double,single) cells
		foreach stat in corr se {
			matrix `stat'[rownumb(`stat',"`singleVar'"),colnumb(`stat',"`doubleVar'")] = round(`r(`stat')',0.00000001)
			matrix `stat'[rownumb(`stat',"`doubleVar'"),colnumb(`stat',"`singleVar'")] = round(`r(`stat')',0.00000001)
		}
	}
}
** iii) One-measure variable vs two-measure variable, without the equalvar option
foreach singleVar of local oneByTwoListNotEV {
	foreach doubleVar of local twoByTwoList {
		display "`singleVar'"
		display "`doubleVar'"
		ORIVCorrelation [pw=weight], group1(`singleVar') group2(`doubleVar'1 `doubleVar'2) bsreps($bootstrapReps)
		** One estimate fills both the (single,double) and (double,single) cells
		foreach stat in corr se {
			matrix `stat'[rownumb(`stat',"`singleVar'"),colnumb(`stat',"`doubleVar'")] = round(`r(`stat')',0.00000001)
			matrix `stat'[rownumb(`stat',"`doubleVar'"),colnumb(`stat',"`singleVar'")] = round(`r(`stat')',0.00000001)
		}
	}
}
** iv) Both variables have a single measure (1x1): one call for every ordered
**     pair of distinct one-measure variables
foreach rowVar of local oneByTwoListComb {
	** Every one-measure variable except the current row variable
	local partners : list oneByTwoListComb - rowVar
	foreach colVar of local partners {
		display "`rowVar'"
		display "`colVar'"
		ORIVCorrelation [pw=weight], group1(`rowVar') group2(`colVar') bsreps($bootstrapReps)
		** Store the estimate and its SE, rounded to 8 decimals
		foreach stat in corr se {
			matrix `stat'[rownumb(`stat',"`rowVar'"),colnumb(`stat',"`colVar'")] = round(`r(`stat')',0.00000001)
		}
	}
}

** v) Correlations with non-orthogonal element of measurement error.
**    The list is read as consecutive pairs (words 1-2, 3-4, ...); a trailing
**    unpaired word is ignored.
local numNonOrthog : word count `twoByTwoListNonOrthog'
display `numNonOrthog'
** Position of the last word that can start a complete pair
local lastPairStart = `numNonOrthog' - 1
forvalues pos = 1(2)`lastPairStart' {
	local firstVar  : word `pos' of `twoByTwoListNonOrthog'
	local secondVar : word `=`pos'+1' of `twoByTwoListNonOrthog'
	display "`firstVar'"
	display "`secondVar'"

	** Estimate with the first variable as group1; fills cell (first, second)
	ORIVCorrelationNonOrthog [pw=weight], group1(`firstVar'1 `firstVar'2) group2(`secondVar'1 `secondVar'2) bsreps($bootstrapReps)
	foreach stat in corr se {
		matrix `stat'[rownumb(`stat',"`firstVar'"),colnumb(`stat',"`secondVar'")] = round(`r(`stat')',0.00000001)
	}

	** Repeat with the roles reversed; fills cell (second, first)
	ORIVCorrelationNonOrthog [pw=weight], group1(`secondVar'1 `secondVar'2) group2(`firstVar'1 `firstVar'2) bsreps($bootstrapReps)
	foreach stat in corr se {
		matrix `stat'[rownumb(`stat',"`secondVar'"),colnumb(`stat',"`firstVar'")] = round(`r(`stat')',0.00000001)
	}
}


// pcamat needs ones on the diagonal, so build corr2 as a copy of corr with
// a unit diagonal and the same row and column names
mata
	// Copy the Stata matrix into Mata and overwrite its diagonal with 1
	corrMatrix=st_matrix("corr")
	_diag(corrMatrix,1)
	// Read the row/column names so they can be attached to the new matrix
	row=st_matrixrowstripe("corr")
	col=st_matrixcolstripe("corr")
	// Write the result back to Stata as corr2, then restore the names
	st_matrix("corr2",corrMatrix)
	st_matrixrowstripe("corr2",row)
	st_matrixcolstripe("corr2",col)
end
// Only corr2 is kept; the original corr matrix is no longer needed
matrix drop corr

// Reload the dataset that was saved at the top of this file
use `correlationsTemp', clear
